/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file conv3d_bp_tiling_util.h
 * \brief
 */

#ifndef TILING_CONV_BACKPROP_CONV3D_BP_TILING_UTIL_H
#define TILING_CONV_BACKPROP_CONV3D_BP_TILING_UTIL_H

#include <iostream>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include <cstddef>

namespace ConvBackpropApi {

constexpr uint32_t DB_ON = 2;
constexpr uint64_t L1_SIZE = 524160; // 512 * 1024 - 128;
constexpr int64_t L0_SIZE = 65536;
constexpr uint32_t BEST_BASE_M = 128;
constexpr uint32_t BEST_BASE_K = 128;
constexpr uint32_t BEST_BASE_N = 256;
constexpr int64_t WORKSIZE = 16777216; // 16 * 1024 * 1024;
constexpr int32_t BUFFER_NUM_L1 = 4;

// dw
constexpr uint32_t C04 = 4;
constexpr uint32_t C16 = 16;
constexpr int32_t MIN_BATCHDIM = 4;
constexpr uint32_t SECOND_BASE_M = 64;
constexpr uint32_t SECOND_BASE_K = 64;
constexpr uint32_t SECOND_BASE_N = 512;
constexpr uint32_t MIN_STEPK = 2;
constexpr uint32_t ROW_NUM = 2;

constexpr uint32_t C0_BYTE_SIZE = 32;
constexpr uint32_t LOAD3D_MAX_STRIDE_H_W = 63;
constexpr uint32_t LOAD3D_MAX_DILATION_H_W = 255;
constexpr uint32_t LOAD3D_MAX_PAD = 255;
constexpr uint32_t LOAD3D_MAX_FILTER_H_W = 511;
constexpr uint32_t LOAD3D_MAX_DDR2L1_SIZE = 65535;

constexpr uint32_t DB_OFF = 1;
constexpr const uint32_t ROW_FIRST = 1;
constexpr const uint32_t COL_FIRST = 2;

constexpr int32_t KERNEL_HW_4 = 4;
constexpr int32_t KERNEL_HW_9 = 9;
constexpr int32_t KERNEL_HW_16 = 16;
constexpr uint32_t NUM_HALF = 2;
constexpr uint32_t BASIC_BLOCK_SIZE_256 = 256;
constexpr uint32_t BASIC_BLOCK_SIZE_128 = 128;
constexpr uint32_t BASIC_BLOCK_SIZE_64 = 64;
constexpr uint32_t SPLIT_M_K = 1;
constexpr uint32_t SPLIT_N_K = 2;
constexpr uint32_t SPLIT_M_N = 3;
constexpr uint32_t L1_DEPTH_16 = 16;
constexpr uint32_t L1_DEPTH_8 = 8;
constexpr uint32_t L1_DEPTH_4 = 4;
constexpr uint32_t L1_DEPTH_2 = 2;
constexpr uint32_t STEP_2 = 2;

// dx only
constexpr size_t OUTPUT_PADDING_DIM = 5;
constexpr int32_t BYTE_BLOCK = 32;
constexpr uint32_t B16_BITS = 4;
constexpr uint32_t FP32_BITS = 3;
constexpr uint32_t FP32_DATA_SIZE = 4;
constexpr uint32_t F16_DATA_SIZE = 2; //Shared by BF16 and FP16
constexpr uint32_t NUM_FP32_C1OUT = 2;
constexpr int32_t FMAP_H_NUM = 2;

const size_t Y_INDEX = 0;
const size_t FILTER_INDEX = 1;
const size_t OUTPUT_BP_INDEX = 2;
const size_t BAIS_INDEX = 3;
const size_t OFFSET_W_INDEX = 4;
const size_t OUTPUT_PADDING_INDEX = 5;
const size_t OFFSET_X_INDEX = 6;

const int32_t DIM_LOW = 1;
const int32_t PAD_DIM_LOW = 0;
const int32_t PAD_DIM_UP = 255;
const int32_t STRIDES_DIM_HW_UP = 63;
const int32_t STRIDES_DIM_DEPTH_UP = 255;
const int32_t GROUPS_LOW = 1;
const int32_t GROUPS_UP = 65535;
const int32_t K_START_POSITION_MAX = 65535;
const int32_t K_NUM_TWO = 2;

#define OP_TILING_CHECK(cond, log_func, expr) \
  do {                                        \
    if (cond) {                               \
      log_func;                               \
      expr;                                   \
    }                                         \
  } while (0)


#define D_OP_LOGE(opname, fmt, ...) OpLogSub(OP, DLOG_ERROR, opname, fmt, ##__VA_ARGS__)
#define OP_LOGE_WITHOUT_REPORT(opname, ...) D_OP_LOGE(get_op_info(opname), __VA_ARGS__)

#define CUBE_INNER_ERR_REPORT(op_name, err_msg, ...)                                                 \
  do {                                                                                               \
    OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__);                                         \
    REPORT_INNER_ERROR("E69999", "op[%s], " err_msg, get_cstr(get_op_info(op_name)), ##__VA_ARGS__); \
  } while (0)


/**
 * if y is 0, return x
 */
template <typename T>
typename std::enable_if <std::is_signed<T>::value, T>::type CeilDiv(T x, T y) {
  if (y != 0 && x != 0) {
    const T quotient = x / y;
    return (x % y != 0 && ((x > 0) == (y > 0))) ? (quotient + 1) : quotient;
  }

  return x;
}

/**
 * if y is 0, return x
 */
template <typename T>
typename std::enable_if <std::is_unsigned<T>::value, T>::type CeilDiv(T x, T y) {
  if (y != 0 && x != 0) {
    const T quotient = x / y;
    return (x % y != 0) ? (quotient + 1) : quotient;
  }

  return x;
}

template <typename T>
inline T CeilDivision(T num1, T num2) {
    return CeilDiv(num1, num2);
}

inline int64_t CeilDivision(int64_t num1, int32_t num2) {
    return CeilDiv(num1, static_cast<int64_t>(num2));
}

template <typename T>
inline T Min(T num1, T num2) {
    return std::min(num1, num2);
}

inline int32_t Min(int64_t num1, int32_t num2) {
  return static_cast<int32_t>(std::min(num1, static_cast<int64_t>(num2)));
}

inline int32_t Min(int32_t num1, int64_t num2) {
  return static_cast<int32_t>(std::min(num2, static_cast<int64_t>(num1)));
}

inline int32_t CalcHi(int32_t ho, int32_t strideH, int32_t kernelHDilation, int32_t oriHi) {
    return Min(static_cast<int64_t>(ho - 1) * strideH + kernelHDilation, oriHi);
}

inline int32_t CalcHo(int64_t k, int32_t wo) {
    if (k == 0 || wo == 0) {
        return 0;
    }
// The complete K is ho*wo, k may exceed int32, but wo after the following division cannot exceed int32
    int32_t ho = static_cast<int32_t>(CeilDivision(k, wo));
    if (k % wo == 0 || wo % k == 0) {
        return ho;
    } else {
        return ho + 1;
    }
}

inline int32_t GetGcd(int32_t param1, int32_t param2) {
    // get greatest common divisor of param1 and param2
    if (param1 < param2) {
        std::swap(param1, param2);
    }
    if (param2 == 0) {
        return 0;
    }
    if (param1 % param2 == 0) {
        return param2;
    } else {
        return GetGcd(param2, param1 - param2);
    }
}

inline void GetFactors(std::vector<int32_t> &factorList, int64_t srcNum, int32_t maxFactor) {
    int32_t max_num = Min(srcNum, maxFactor);
    for (int32_t factor = 1; factor <= max_num; factor++) {
        if (srcNum % factor == 0) {
            factorList.push_back(factor);
        }
    }
}   

template <typename T>
inline typename std::enable_if <std::is_integral<T>::value, T>::type CeilAlign(T x, T align) {
    return CeilDiv(x, align) * align;
}

inline int64_t Lcm(int64_t param1, int64_t param2) {
  int64_t pram1Lcm = param1;
  int64_t pram2Lcm = param2;
  int64_t temp = pram1Lcm * pram2Lcm;
  int64_t param1Temp = pram1Lcm;
  while (pram1Lcm % pram2Lcm != 0) {
    param1Temp = pram1Lcm;
    pram1Lcm = pram2Lcm;
    pram2Lcm = param1Temp % pram2Lcm;
  }
  return temp / pram2Lcm;
}

inline int64_t Lcm(int32_t param1, int32_t param2) {
  return Lcm(static_cast<int64_t>(param1), static_cast<int64_t>(param2));
}

inline int64_t Lcm(int64_t param1, int32_t param2) {
  return Lcm(param1, static_cast<int64_t>(param2));
}

/**
 * if align is 0, return 0
 */
template <typename T>
typename std::enable_if <std::is_integral<T>::value, T>::type FloorAlign(T x, T align) {
  return align == 0 ? 0 : x / align * align;
}

}
#endif // TILING_CONV_BACKPROP_CONV3D_BP_TILING_UTIL_H